{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-Wc92cWK-Aas"
   },
   "source": [
    "# Getting started with OWL-ViT\n",
    "\n",
    "In this notebook, we are going to run [OWL-ViT](https://arxiv.org/abs/2205.06230), an open-vocabulary object detection model by Google Research, on scikit-image sample images.\n",
    "\n",
    "## OWL-ViT: A Quick Intro\n",
    "OWL-ViT is an open-vocabulary object detector. Given an image and one or more free-text queries, it finds objects in the image that match the queries. Unlike traditional object detection models, OWL-ViT is not limited to a fixed set of class labels: it leverages multi-modal representations to perform open-vocabulary detection.\n",
    "\n",
    "OWL-ViT uses CLIP with a ViT-like Transformer as its backbone to get multi-modal visual and text features. To use CLIP for object detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a lightweight classification and box head to each transformer output token. Open-vocabulary classification is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image can be used to perform zero-shot text-conditioned object detection.\n",
    "\n",
    "![owlvit architecture](https://raw.githubusercontent.com/google-research/scenic/a41d24676f64a2158bfcd7cb79b0a87673aa875b/scenic/projects/owl_vit/data/owl_vit_schematic.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uIcaig48T6yv"
   },
   "source": [
    "## Set-up environment\n",
    "\n",
    "First, we install the Hugging Face Transformers library from source, as the model was recently added to the library and is under active development. This might take a few minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "_XLma_DL3S9-",
    "outputId": "3d8e2639-90a0-4820-bf7f-04eee42e5258"
   },
   "outputs": [],
   "source": [
    "!pip install -q git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Optional:** Install Pillow, matplotlib, OpenCV and scikit-image if you are running this notebook locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install Pillow\n",
    "!pip install matplotlib\n",
    "!pip install opencv-python\n",
    "!pip install scikit-image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QPrCVnimE0qR"
   },
   "source": [
    "## Load pre-trained model and processor\n",
    "\n",
    "Let's first load the model and its `OwlViTProcessor`, which applies the image preprocessing and tokenizes the text queries. The processor resizes the image(s), rescales the pixel values to the [0, 1] range and normalizes each channel using the mean and standard deviation specified in the original codebase.\n",
    "\n",
    "Text queries are tokenized using a CLIP tokenizer and stacked into output tensors of shape `[batch_size * num_max_text_queries, sequence_length]`. If you pass more than one (image, text queries) pair, `num_max_text_queries` is the maximum number of text queries per image across the batch. Input samples with fewer text queries are padded."
   ]
  },
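  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding described above can be sketched in plain Python. This is a minimal illustration and not the processor's actual implementation: `pad_text_queries` is a hypothetical helper, and padding with a single-space string is an assumption made for this example. The flattened list has `batch_size * num_max_text_queries` entries, one per row of the tokenized output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (not part of transformers): pad every query list\n",
    "# to the longest one in the batch, then flatten the result.\n",
    "def pad_text_queries(batch_queries, pad_query=\" \"):\n",
    "    num_max_text_queries = max(len(queries) for queries in batch_queries)\n",
    "    padded = [\n",
    "        queries + [pad_query] * (num_max_text_queries - len(queries))\n",
    "        for queries in batch_queries\n",
    "    ]\n",
    "    # One entry per (image, query) slot: batch_size * num_max_text_queries\n",
    "    flat = [query for queries in padded for query in queries]\n",
    "    return flat, num_max_text_queries\n",
    "\n",
    "example_queries = [[\"human face\", \"rocket\", \"nasa badge\"], [\"coffee mug\"]]\n",
    "flat_queries, num_max_text_queries = pad_text_queries(example_queries)\n",
    "print(f\"{len(flat_queries)} padded queries, {num_max_text_queries} per image\")"
   ]
  },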
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AD8DXCnJ7faH",
    "outputId": "3207f732-be55-46a2-ce3e-346e32efeac9"
   },
   "outputs": [],
   "source": [
    "from transformers import OwlViTProcessor, OwlViTForObjectDetection\n",
    "\n",
    "model = OwlViTForObjectDetection.from_pretrained(\"google/owlvit-base-patch32\")\n",
    "processor = OwlViTProcessor.from_pretrained(\"google/owlvit-base-patch32\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7QN-vURe3euV"
   },
   "source": [
    "## Preprocess input image and text queries\n",
    "\n",
    "Let's use the image of astronaut Eileen Collins to test OWL-ViT. It's part of the [NASA](https://www.nasa.gov/multimedia/imagegallery/index.html) Great Images dataset.\n",
    "\n",
    "You can use one or multiple text prompts per image to search for the target object(s). Let's start with a simple example where we search for multiple objects in a single image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 529
    },
    "id": "LOX1__3nrezW",
    "outputId": "4435e2c9-b79d-4aaf-d7ef-b1cf2c239d01"
   },
   "outputs": [],
   "source": [
    "import cv2\n",
    "import skimage.data\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "# Load the sample image bundled with scikit-image\n",
    "image = skimage.data.astronaut()\n",
    "image = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
    "\n",
    "# Text queries to search the image for\n",
    "text_queries = [\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"]\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_Fu_0bM2Mz0R"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Use GPU if available\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "S6lQ2d9uodiv",
    "outputId": "f5634352-d44e-47c4-f66c-4ad51522efdb"
   },
   "outputs": [],
   "source": [
    "# Process image and text inputs\n",
    "inputs = processor(text=text_queries, images=image, return_tensors=\"pt\").to(device)\n",
    "\n",
    "# Print input names and shapes\n",
    "for key, val in inputs.items():\n",
    "    print(f\"{key}: {val.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ten5ZJQsoUbE"
   },
   "source": [
    "## Forward pass\n",
    "\n",
    "Now we can pass the inputs to our OWL-ViT model to get object detection predictions.\n",
    "\n",
    "The `OwlViTForObjectDetection` model outputs the prediction logits, bounding boxes and class embeddings, along with the image and text embeddings produced by `OwlViTModel`, which is the CLIP backbone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "t7oKBoKQEz_w",
    "outputId": "5deb56c6-c302-4175-a854-d4d6e7b7d903"
   },
   "outputs": [],
   "source": [
    "# Move the model to the device and set it to evaluation mode\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "# Get predictions\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "for k, val in outputs.items():\n",
    "    if k not in {\"text_model_output\", \"vision_model_output\"}:\n",
    "        print(f\"{k}: shape of {val.shape}\")\n",
    "\n",
    "print(\"\\nText model outputs\")\n",
    "for k, val in outputs.text_model_output.items():\n",
    "    print(f\"{k}: shape of {val.shape}\")\n",
    "\n",
    "print(\"\\nVision model outputs\")\n",
    "for k, val in outputs.vision_model_output.items():\n",
    "    print(f\"{k}: shape of {val.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h2C5dkQ8sBVZ"
   },
   "source": [
    "## Draw predictions on image\n",
    "\n",
    "Let's draw the detected objects on the input image. Each detection is labeled with the input text query it matched best."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PNYQe2xbl5vl"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from transformers.image_utils import ImageFeatureExtractionMixin\n",
    "mixin = ImageFeatureExtractionMixin()\n",
    "\n",
    "# Resize the image to the model's input resolution\n",
    "image_size = model.config.vision_config.image_size\n",
    "image = mixin.resize(image, image_size)\n",
    "input_image = np.asarray(image).astype(np.float32) / 255.0\n",
    "\n",
    "# Threshold to eliminate low probability predictions\n",
    "score_threshold = 0.1\n",
    "\n",
    "# For each predicted box, take the highest logit across the text queries\n",
    "logits = torch.max(outputs[\"logits\"][0], dim=-1)\n",
    "scores = torch.sigmoid(logits.values).cpu().detach().numpy()\n",
    "\n",
    "# Get prediction labels and bounding boxes\n",
    "labels = logits.indices.cpu().detach().numpy()\n",
    "boxes = outputs[\"pred_boxes\"][0].cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 480
    },
    "id": "Xc6YjWXBl_vK",
    "outputId": "23ed2b75-59f0-461b-ce8d-48574de96f92"
   },
   "outputs": [],
   "source": [
    "def plot_predictions(input_image, text_queries, scores, boxes, labels):\n",
    "    # Note: score_threshold is read from the notebook's global scope\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
    "    # Map the image onto the unit square so normalized box coordinates line up\n",
    "    ax.imshow(input_image, extent=(0, 1, 1, 0))\n",
    "    ax.set_axis_off()\n",
    "\n",
    "    for score, box, label in zip(scores, boxes, labels):\n",
    "        if score < score_threshold:\n",
    "            continue\n",
    "\n",
    "        # Draw the box outline from its normalized [cx, cy, w, h] coordinates\n",
    "        cx, cy, w, h = box\n",
    "        ax.plot(\n",
    "            [cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2, cx - w / 2],\n",
    "            [cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2],\n",
    "            \"r\",\n",
    "        )\n",
    "        # Label the box with the matched query and its score\n",
    "        ax.text(\n",
    "            cx - w / 2,\n",
    "            cy + h / 2 + 0.015,\n",
    "            f\"{text_queries[label]}: {score:1.2f}\",\n",
    "            ha=\"left\",\n",
    "            va=\"top\",\n",
    "            color=\"red\",\n",
    "            bbox={\n",
    "                \"facecolor\": \"white\",\n",
    "                \"edgecolor\": \"red\",\n",
    "                \"boxstyle\": \"square,pad=.3\",\n",
    "            },\n",
    "        )\n",
    "\n",
    "plot_predictions(input_image, text_queries, scores, boxes, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WBgWR3yONIwT"
   },
   "source": [
    "## Batch processing\n",
    "We can also pass in multiple sets of images and text queries to search for different (or the same) objects in different images. Let's load an image of a coffee mug to process alongside the astronaut image.\n",
    "\n",
    "For batch processing, we need to pass the text queries to `OwlViTProcessor` as a nested list and the images as a list of PIL images, PyTorch tensors or NumPy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 417
    },
    "id": "fZJrboQlswuA",
    "outputId": "ac0c0d96-f84f-4938-dadc-9a80669cb5a7"
   },
   "outputs": [],
   "source": [
    "# Load the coffee mug image bundled with scikit-image\n",
    "image = skimage.data.coffee()\n",
    "image = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NVLvC8IrNPQW",
    "outputId": "a0c49462-2137-4a02-8971-f40765ef35fa"
   },
   "outputs": [],
   "source": [
    "# Preprocessing\n",
    "images = [skimage.data.astronaut(), skimage.data.coffee()]\n",
    "images = [Image.fromarray(np.uint8(img)).convert(\"RGB\") for img in images]\n",
    "\n",
    "# Nested list of text queries to search each image for\n",
    "text_queries = [[\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"], [\"coffee mug\", \"spoon\", \"plate\"]]\n",
    "\n",
    "# Process image and text inputs\n",
    "inputs = processor(text=text_queries, images=images, return_tensors=\"pt\").to(device)\n",
    "\n",
    "# Print input names and shapes\n",
    "for key, val in inputs.items():\n",
    "    print(f\"{key}: {val.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i9Q2ik1kNZ1j"
   },
   "source": [
    "**Note:** The shape of `input_ids` and `attention_mask` is `[batch_size * num_max_text_queries, max_length]`. `max_length` is set to 16 for all OWL-ViT models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ApJXe-2sNR9-",
    "outputId": "61756193-4fae-4d9d-90a2-13da26ae185e"
   },
   "outputs": [],
   "source": [
    "# Get predictions\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "for k, val in outputs.items():\n",
    "    if k not in {\"text_model_output\", \"vision_model_output\"}:\n",
    "        print(f\"{k}: shape of {val.shape}\")\n",
    "\n",
    "print(\"\\nText model outputs\")\n",
    "for k, val in outputs.text_model_output.items():\n",
    "    print(f\"{k}: shape of {val.shape}\")\n",
    "\n",
    "print(\"\\nVision model outputs\")\n",
    "for k, val in outputs.vision_model_output.items():\n",
    "    print(f\"{k}: shape of {val.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's plot the predictions for the second image\n",
    "image_idx = 1\n",
    "image_size = model.config.vision_config.image_size\n",
    "image = mixin.resize(images[image_idx], image_size)\n",
    "input_image = np.asarray(image).astype(np.float32) / 255.0\n",
    "\n",
    "# Threshold to eliminate low probability predictions\n",
    "score_threshold = 0.1\n",
    "\n",
    "# For each predicted box, take the highest logit across the text queries\n",
    "logits = torch.max(outputs[\"logits\"][image_idx], dim=-1)\n",
    "scores = torch.sigmoid(logits.values).cpu().detach().numpy()\n",
    "\n",
    "# Get prediction labels and bounding boxes\n",
    "labels = logits.indices.cpu().detach().numpy()\n",
    "boxes = outputs[\"pred_boxes\"][image_idx].cpu().detach().numpy()\n",
    "\n",
    "plot_predictions(input_image, text_queries[image_idx], scores, boxes, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n6VYaXGbNfwU"
   },
   "source": [
    "## Post-processing model predictions\n",
    "Notice how we plotted the predictions on the resized input image. This is because OWL-ViT outputs normalized box coordinates in `[cx, cy, w, h]` format, assuming a fixed input image size. We can use `OwlViTProcessor`'s `post_process()` method to convert the model outputs to the COCO API format and retrieve coordinates rescaled to the original image sizes in `[x0, y0, x1, y1]` format."
   ]
  },
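  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The coordinate conversion described above can be sketched with NumPy. This is an illustrative sketch of the arithmetic only, not the processor's implementation: `cxcywh_to_xyxy` is a hypothetical helper, and the example box and image size are made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: convert normalized [cx, cy, w, h] boxes to\n",
    "# absolute [x0, y0, x1, y1] pixel coordinates for a (height, width) image.\n",
    "def cxcywh_to_xyxy(norm_boxes, image_height, image_width):\n",
    "    cx, cy, w, h = norm_boxes.T\n",
    "    corners = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)\n",
    "    # x coordinates scale with the width, y coordinates with the height\n",
    "    scale = np.array([image_width, image_height, image_width, image_height])\n",
    "    return corners * scale\n",
    "\n",
    "example_boxes = np.array([[0.5, 0.5, 0.2, 0.4]])  # made-up normalized box\n",
    "print(cxcywh_to_xyxy(example_boxes, image_height=400, image_width=600))"
   ]
  },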
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "R3xneq28NcUm",
    "outputId": "802a1707-e309-4c96-88bd-a8882afee8d9"
   },
   "outputs": [],
   "source": [
    "# Target image sizes (height, width) to rescale box predictions [batch_size, 2]\n",
    "target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)\n",
    "\n",
    "# Convert outputs (bounding boxes and class logits) to the COCO API format\n",
    "results = processor.post_process(outputs=outputs, target_sizes=target_sizes)\n",
    "\n",
    "# Loop over predictions for each image in the batch\n",
    "for i in range(len(images)):\n",
    "    print(f\"\\nProcessing image {i}\")\n",
    "    text = text_queries[i]\n",
    "    boxes, scores, labels = results[i][\"boxes\"], results[i][\"scores\"], results[i][\"labels\"]\n",
    "\n",
    "    score_threshold = 0.1\n",
    "    for box, score, label in zip(boxes, scores, labels):\n",
    "        box = [round(coord, 2) for coord in box.tolist()]\n",
    "\n",
    "        if score >= score_threshold:\n",
    "            print(f\"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus: one-shot / image-guided object detection\n",
    "Instead of performing zero-shot detection with text inputs, we can use the `OwlViTForObjectDetection.image_guided_detection()` method to query an input image with a query / example image and detect similar looking objects. To do this, we simply pass in `query_images` instead of text to the processor to get the `query_pixel_values`. Note that, unlike text input, `OwlViTProcessor` expects one query image per target image we'd like to query for similar objects. We will also see that the output and post-processing of one-shot object detection is very similar to the zero-shot / text-guided detection.\n",
    "\n",
    "Let's try this out by querying an image with cats with another random cat image. For this part of the demo, we will perform image-guided object detection, post-process the results and display the predicted boundary boxes on the original input image using OpenCV."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import requests\n",
    "from matplotlib import rcParams\n",
    "\n",
    "# Set figure size\n",
    "%matplotlib inline\n",
    "rcParams[\"figure.figsize\"] = (11, 8)\n",
    "\n",
    "# Input image\n",
    "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "target_sizes = torch.Tensor([image.size[::-1]])\n",
    "\n",
    "# Query image\n",
    "query_url = \"http://images.cocodataset.org/val2017/000000058111.jpg\"\n",
    "query_image = Image.open(requests.get(query_url, stream=True).raw)\n",
    "\n",
    "# Display input image and query image\n",
    "fig, ax = plt.subplots(1,2)\n",
    "ax[0].imshow(image)\n",
    "ax[1].imshow(query_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process input and query image\n",
    "inputs = processor(images=image, query_images=query_image, return_tensors=\"pt\").to(device)\n",
    "\n",
    "# Print input names and shapes\n",
    "for key, val in inputs.items():\n",
    "    print(f\"{key}: {val.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get predictions\n",
    "with torch.no_grad():\n",
    "    outputs = model.image_guided_detection(**inputs)\n",
    "\n",
    "for k, val in outputs.items():\n",
    "    if k not in {\"text_model_output\", \"vision_model_output\"}:\n",
    "        print(f\"{k}: shape of {val.shape}\")\n",
    "\n",
    "print(\"\\nVision model outputs\")\n",
    "for k, val in outputs.vision_model_output.items():\n",
    "    print(f\"{k}: shape of {val.shape}\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the PIL image (RGB) to an OpenCV array (BGR) for drawing\n",
    "img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)\n",
    "outputs.logits = outputs.logits.cpu()\n",
    "outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()\n",
    "\n",
    "results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes)\n",
    "boxes, scores = results[0][\"boxes\"], results[0][\"scores\"]\n",
    "\n",
    "# Draw predicted bounding boxes\n",
    "for box in boxes:\n",
    "    box = [int(coord) for coord in box.tolist()]\n",
    "    img = cv2.rectangle(img, tuple(box[:2]), tuple(box[2:]), (255, 0, 0), 5)\n",
    "\n",
    "# Flip the channels from BGR back to RGB for matplotlib\n",
    "plt.imshow(img[:, :, ::-1])"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "OWL-ViT-inference example.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
